{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acab479f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from accelerate.logging import get_logger\n",
    "from diffusers import StableDiffusionPipeline\n",
    "from diffusers.utils import check_min_version\n",
    "\n",
    "from peft import PeftModel\n",
    "\n",
    "# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\n",
    "check_min_version(\"0.10.0.dev0\")\n",
    "\n",
    "# NOTE(review): `logger` is not used by any cell in this notebook.\n",
    "logger = get_logger(__name__)\n",
    "\n",
    "# Base Stable Diffusion model the adapters are attached to (the commented line is an SD v1.5 alternative).\n",
    "MODEL_NAME = \"stabilityai/stable-diffusion-2-1\"\n",
    "# MODEL_NAME=\"runwayml/stable-diffusion-v1-5\"\n",
    "\n",
    "# In this notebook BLOCK_NUM / BLOCK_SIZE / N_BUTTERFLY_FACTOR are only concatenated into RUN_NAME;\n",
    "# presumably they mirror the BOFT settings used for the training run -- confirm against the training config.\n",
    "PEFT_TYPE=\"boft\"\n",
    "BLOCK_NUM=8\n",
    "BLOCK_SIZE=0\n",
    "N_BUTTERFLY_FACTOR=1\n",
    "SELECTED_SUBJECT=\"backpack\"\n",
    "# Epoch sub-directory of the saved adapter checkpoints to load.\n",
    "EPOCH_IDX = 200\n",
    "\n",
    "# NOTE(review): PROJECT_NAME is not used by any cell in this notebook.\n",
    "PROJECT_NAME=f\"dreambooth_{PEFT_TYPE}\"\n",
    "# RUN_NAME is used both as the adapter name and as the checkpoint directory name under\n",
    "# OUTPUT_DIR/unet/<epoch>/ (and text_encoder/<epoch>/), so it must match the saved run exactly.\n",
    "RUN_NAME=f\"{SELECTED_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}\"\n",
    "OUTPUT_DIR=f\"./data/output/{PEFT_TYPE}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06cfd506",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_boft_sd_pipeline(\n",
    "    ckpt_dir, base_model_name_or_path=None, epoch=None, dtype=torch.float32, device=\"cuda\", adapter_name=\"default\"\n",
    "):\n",
    "    \"\"\"Build a Stable Diffusion pipeline with a trained BOFT adapter attached.\n",
    "\n",
    "    Args:\n",
    "        ckpt_dir: Root directory of the saved adapters; weights are read from\n",
    "            `<ckpt_dir>/unet/<epoch>/<adapter_name>` and, if it exists,\n",
    "            `<ckpt_dir>/text_encoder/<epoch>/<adapter_name>`.\n",
    "        base_model_name_or_path: Hub id or local path of the base Stable Diffusion model.\n",
    "        epoch: Checkpoint epoch, i.e. the sub-directory name to load.\n",
    "        dtype: Torch dtype of the pipeline weights (e.g. torch.float16).\n",
    "        device: Device the pipeline is moved to.\n",
    "        adapter_name: Name the adapter is registered under; also its sub-directory name.\n",
    "\n",
    "    Returns:\n",
    "        The StableDiffusionPipeline with the adapter loaded.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `base_model_name_or_path` or `epoch` is not given.\n",
    "    \"\"\"\n",
    "    if base_model_name_or_path is None:\n",
    "        raise ValueError(\"Please specify the base model name or path\")\n",
    "    # Without an epoch the checkpoint path cannot be built (a type-object default\n",
    "    # would silently produce a bogus \"unet/<class 'int'>\" directory).\n",
    "    if epoch is None:\n",
    "        raise ValueError(\"Please specify the checkpoint epoch to load\")\n",
    "\n",
    "    pipe = StableDiffusionPipeline.from_pretrained(\n",
    "        base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False\n",
    "    ).to(device)\n",
    "\n",
    "    load_adapter(pipe, ckpt_dir, epoch, adapter_name)\n",
    "\n",
    "    # Cast after loading so the freshly attached adapter weights match the requested dtype.\n",
    "    # `.to(dtype)` (not `.half()`) keeps bfloat16 as bfloat16 instead of forcing float16.\n",
    "    if dtype in (torch.float16, torch.bfloat16):\n",
    "        pipe.unet.to(dtype)\n",
    "        pipe.text_encoder.to(dtype)\n",
    "\n",
    "    pipe.to(device)\n",
    "    return pipe\n",
    "\n",
    "\n",
    "def load_adapter(pipe, ckpt_dir, epoch, adapter_name=\"default\"):\n",
    "    \"\"\"Attach the adapter saved for `epoch` under `ckpt_dir` to the pipeline.\n",
    "\n",
    "    The UNet adapter is always loaded from `<ckpt_dir>/unet/<epoch>/<adapter_name>`; the\n",
    "    text-encoder adapter is loaded only if `<ckpt_dir>/text_encoder/<epoch>/<adapter_name>`\n",
    "    exists. A component that is already a PeftModel receives the adapter through its\n",
    "    `load_adapter` method; otherwise the component is replaced by a new PeftModel wrapping it.\n",
    "    \"\"\"\n",
    "\n",
    "    def checkpoint_dir(component):\n",
    "        return os.path.join(ckpt_dir, f\"{component}/{epoch}\", adapter_name)\n",
    "\n",
    "    def attach(component, adapter_dir):\n",
    "        module = getattr(pipe, component)\n",
    "        if not isinstance(module, PeftModel):\n",
    "            wrapped = PeftModel.from_pretrained(module, adapter_dir, adapter_name=adapter_name)\n",
    "            setattr(pipe, component, wrapped)\n",
    "            return\n",
    "        module.load_adapter(adapter_dir, adapter_name=adapter_name)\n",
    "\n",
    "    attach(\"unet\", checkpoint_dir(\"unet\"))\n",
    "\n",
    "    text_encoder_dir = checkpoint_dir(\"text_encoder\")\n",
    "    if os.path.exists(text_encoder_dir):\n",
    "        attach(\"text_encoder\", text_encoder_dir)\n",
    "\n",
    "\n",
    "def set_adapter(pipe, adapter_name):\n",
    "    \"\"\"Activate `adapter_name` on the UNet and, if it is a PeftModel, on the text encoder.\n",
    "\n",
    "    NOTE(review): if the text encoder was wrapped for an earlier adapter but `adapter_name`\n",
    "    has no text-encoder weights, the text-encoder call presumably raises -- confirm.\n",
    "    \"\"\"\n",
    "    pipe.unet.set_adapter(adapter_name)\n",
    "    text_encoder = pipe.text_encoder\n",
    "    if isinstance(text_encoder, PeftModel):\n",
    "        text_encoder.set_adapter(adapter_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98a0d8ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompts for the first adapter. \"sks\" is presumably the identifier token the subject was bound to\n",
    "# during DreamBooth training -- confirm against the training prompt.\n",
    "prompt = \"a photo of sks backpack on a wooden floor\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e888d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Build the pipeline (float32 on \"cuda\" by default) and attach adapter RUN_NAME from the EPOCH_IDX checkpoint.\n",
    "pipe = get_boft_sd_pipeline(OUTPUT_DIR, MODEL_NAME, EPOCH_IDX, adapter_name=RUN_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1c1a1c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Fixed-seed CPU generator so re-running the cell reproduces the same image.\n",
    "generator = torch.Generator(device=\"cpu\").manual_seed(0)\n",
    "image = pipe(\n",
    "    prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt, generator=generator\n",
    ").images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a1aafdf-8cf7-4e47-9471-26478034245e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a second adapter into the same pipeline and make it the active one.\n",
    "# WARNING: requires training DreamBooth with `boft_bias=None`\n",
    "# Dedicated names keep the config cell's SELECTED_SUBJECT / EPOCH_IDX / RUN_NAME untouched,\n",
    "# so re-running the earlier loading cell still loads the first adapter.\n",
    "\n",
    "DOG_SUBJECT = \"dog\"\n",
    "DOG_EPOCH_IDX = 200\n",
    "DOG_RUN_NAME = f\"{DOG_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}\"\n",
    "\n",
    "load_adapter(pipe, OUTPUT_DIR, epoch=DOG_EPOCH_IDX, adapter_name=DOG_RUN_NAME)\n",
    "set_adapter(pipe, adapter_name=DOG_RUN_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7091ad0-2005-4528-afc1-4f9d70a9a535",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Distinct names keep `prompt` / `image` from the first generation intact if earlier cells are re-run.\n",
    "dog_prompt = \"a photo of sks dog running on the beach\"\n",
    "negative_prompt = \"low quality, blurry, unfinished\"\n",
    "# Fixed-seed CPU generator so re-running the cell reproduces the same image.\n",
    "generator = torch.Generator(device=\"cpu\").manual_seed(0)\n",
    "dog_image = pipe(\n",
    "    dog_prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt, generator=generator\n",
    ").images[0]\n",
    "dog_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f534eca2-94a4-432b-b092-7149ac44b12f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:peft] *",
   "language": "python",
   "name": "conda-env-peft-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
